package cn.lingyangwl.framework.tool.core.thread;

import cn.lingyangwl.framework.tool.core.exception.Assert;
import org.apache.commons.lang3.time.StopWatch;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.function.Function;

/**
 * 异步任务执行器, 有如下功能
 * 1. 可以添加多个任务, 并发执行, 主线程会等待所有任务执行完毕才会往下执行
 * 2. 可以多线程查询数据
 *
 * @author shenguangyang
 * @since v1.0.0
 */
public class AsyncTaskExecutor {
    private static final Logger log = LoggerFactory.getLogger(AsyncTaskExecutor.class);
    private Executor executor;

    /**
     * 所有的任务列表都是并发执行的
     */
    private final List<Task> taskList = new CopyOnWriteArrayList<>();
    /**
     * 所有前置任务都是先串行运行的
     */
    private final List<Task> preTaskList = new CopyOnWriteArrayList<>();

    public static AsyncTaskExecutor init(Executor executor) {
        Assert.notNull(executor, "executor is null");
        AsyncTaskExecutor asyncTaskExecutor = new AsyncTaskExecutor();
        asyncTaskExecutor.executor = executor;
        return asyncTaskExecutor;
    }

    /**
     * 添加前置任务, 所有的前置任务都是串行执行的
     *
     * @param preProcessorIn 前置处理器输入数据
     * @param preProcessor   接受输入的数据, 并做处理, 然后返回一个结果
     * @param postProcessor  接受前置处理的结果进行再次加工
     * @param <T>            输入
     * @param <R>            前置处理的输出数据
     * @return this
     */
    public <T, R> AsyncTaskExecutor addPreTask(T preProcessorIn, Function<T, R> preProcessor, Consumer<R> postProcessor) {
        Assert.notNull(postProcessor, "postProcessor is null");
        Assert.notNull(preProcessor, "preProcessor is null");
        preTaskList.add(() -> {
            R postProcessIn = preProcessor.apply(preProcessorIn);
            postProcessor.accept(postProcessIn);
        });
        return this;
    }

    /**
     * 添加前置任务, 所有的前置任务都是串行执行的
     *
     * @param processorIn 前置处理器输入数据
     * @param processor   接收输入数据进行加工处理
     * @param <T>         输入
     * @return this
     */
    public <T> AsyncTaskExecutor addPreTask(T processorIn, Consumer<T> processor) {
        Assert.notNull(processor, "processor is null");
        preTaskList.add(() -> processor.accept(processorIn));
        return this;
    }

    /**
     * 添加任务
     *
     * @param preProcessorIn 前置处理器输入数据
     * @param preProcessor   接受输入的数据, 并做处理, 然后返回一个结果
     * @param postProcessor  接受前置处理的结果进行再次加工
     * @param <T>            输入
     * @param <R>            前置处理的输出数据
     * @return this
     */
    public <T, R> AsyncTaskExecutor addTask(T preProcessorIn, Function<T, R> preProcessor, Consumer<R> postProcessor) {
        Assert.notNull(preProcessor, "preProcessor is null");
        Assert.notNull(postProcessor, "postProcessor is null");
        taskList.add(() -> {
            R postProcessIn = preProcessor.apply(preProcessorIn);
            postProcessor.accept(postProcessIn);
        });
        return this;
    }

    /**
     * 添加任务
     *
     * @param processorIn 前置处理器输入数据
     * @param processor   接收输入数据进行加工处理
     * @param <T>         输入
     * @return this
     */
    public <T> AsyncTaskExecutor addTask(T processorIn, Consumer<T> processor) {
        Assert.notNull(processor, "processor is null");
        taskList.add(() -> processor.accept(processorIn));
        return this;
    }

    /**
     * 并发执行任务
     */
    public void execute() throws RuntimeException {
        StopWatch stopWatch = new StopWatch();
        stopWatch.start();
        for (Task task : preTaskList) {
            task.runTask();
        }
        try {
            CountDownLatch taskLatch = new CountDownLatch(taskList.size());
            for (Task task : taskList) {
                executor.execute(() -> {
                    try {
                        task.runTask();
                    } catch (Exception e) {
                        log.error("error: {}", e.getMessage());
                    } finally {
                        taskLatch.countDown();
                    }
                });
            }
            taskLatch.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        stopWatch.stop();
        long time = stopWatch.getTime(TimeUnit.MILLISECONDS);
        log.debug("async task total time: {} ms", time);
    }

    public interface Task {
        void runTask();
    }
}
